import os
from typing import Optional, Tuple

import networkx as nx
import pandas as pd
from omegaconf import DictConfig


# load data
def load_data(cfg: DictConfig) -> Tuple[nx.Graph, pd.DataFrame, pd.DataFrame]:
    """
    Load the observational data, interventional table and causal graph for a dataset.

    Files are read from ``<cwd>/src/data/<cfg.dataset>/``:

    - ``<dataset>_fully_observable.csv`` or ``<dataset>_partially_observable.csv``
      (selected by ``cfg.fully_observable``),
    - ``interventional_table.csv``,
    - ``ground_truth.gml`` or ``ground_truth_partially_observable.gml``
      (selected by ``cfg.fully_observable``).

    Columns of the observational data that are not nodes of the causal graph
    are dropped, and a message is printed for each dropped column.

    Args:
        cfg (DictConfig): Config object; ``cfg.dataset`` and
            ``cfg.fully_observable`` are read.

    Raises:
        FileNotFoundError: If the observational data for the dataset is not found.
            A missing interventional table or graph file also surfaces as
            ``FileNotFoundError`` from the underlying reader.

    Returns:
        tuple: Tuple containing the causal graph, data and interventional table.
    """
    fully_observable = cfg.fully_observable
    dataset_name = cfg.dataset

    suffix = "fully_observable" if fully_observable else "partially_observable"
    data_filename = f"{dataset_name}_{suffix}.csv"

    data_path = os.path.join(os.getcwd(), "src", "data", dataset_name)
    filepath = os.path.join(data_path, data_filename)

    # Fail early with a dataset-specific message if the observational data is missing.
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Data for {dataset_name} not found.")

    data = pd.read_csv(filepath)
    interventional_table = pd.read_csv(os.path.join(data_path, "interventional_table.csv"))

    # Load causal graph.
    # NOTE(review): these files carry a ".gml" extension but are parsed with
    # nx.read_graphml (GraphML format); confirm the on-disk format is GraphML.
    graph_filename = (
        "ground_truth.gml"
        if fully_observable
        else "ground_truth_partially_observable.gml"
    )
    gml_graph: nx.Graph = nx.read_graphml(os.path.join(data_path, graph_filename))

    # Keep only columns that correspond to graph nodes. Collect the offenders
    # first and drop them in one call instead of copying the frame per column.
    graph_nodes = gml_graph.nodes()
    missing_cols = [col for col in data.columns if col not in graph_nodes]
    for col in missing_cols:
        print(f"Column {col} not found in the causal graph. Dropping it from the data.")
    if missing_cols:
        data = data.drop(columns=missing_cols)

    return gml_graph, data, interventional_table



def load_ground_truth(cfg: DictConfig) -> list[dict]:
    """
    Load the treated/control ground-truth data for every configured experiment.

    Each entry of ``cfg.treatment`` is indexed with ``"treatment"`` and
    ``"control"``. For each entry, two CSV files are read from
    ``<cwd>/src/data/<dataset>/ground_truth_data/``:
    ``<prefix>_<treatment>.csv`` and ``<prefix>_<control>.csv``. Here
    ``<prefix>`` is ``<dataset>_fully_observable`` or
    ``<dataset>_partially_observable``, depending on ``cfg.fully_observable``.

    Args:
        cfg (DictConfig): Config object; ``cfg.dataset``,
            ``cfg.fully_observable`` and ``cfg.treatment`` are read.

    Raises:
        FileNotFoundError: If the treated or control ground truth file of any
            experiment is not found. Both files of an experiment are checked
            before either is read.

    Returns:
        list[dict]: One dict per entry of ``cfg.treatment``, in config order, with keys
            ``"treatment"`` and ``"control"`` (the values taken from the config)
            and ``"treated_data"`` and ``"control_data"`` (the loaded DataFrames).
    """
    dataset_name = cfg.dataset
    suffix = "fully_observable" if cfg.fully_observable else "partially_observable"
    prefix = f"{dataset_name}_{suffix}"

    data_path = os.path.join(
        os.getcwd(),
        "src",
        "data",
        dataset_name,
        "ground_truth_data",
    )

    ground_truth_list = []
    for treatment_exp in cfg.treatment:
        treatment_str = treatment_exp["treatment"]
        control_str = treatment_exp["control"]

        filepath_treated = os.path.join(data_path, f"{prefix}_{treatment_str}.csv")
        filepath_control = os.path.join(data_path, f"{prefix}_{control_str}.csv")

        # Check that both ground-truth files exist before reading either one.
        if not os.path.exists(filepath_treated):
            raise FileNotFoundError(f"Treated ground truth data for {dataset_name} not found.")
        if not os.path.exists(filepath_control):
            raise FileNotFoundError(f"Control ground truth data for {dataset_name} not found.")

        ground_truth_list.append(
            {
                "treatment": treatment_str,
                "control": control_str,
                "treated_data": pd.read_csv(filepath_treated),
                "control_data": pd.read_csv(filepath_control),
            }
        )

    return ground_truth_list
